package JUC.threadPool;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

/**
 * @className: ForkJoinPoolTest
 * @Description: Demonstrates divide-and-conquer with {@link ForkJoinPool}: the sum and the
 *     maximum of a large random array are computed by {@link RecursiveTask}s and can be
 *     compared against the sequential stream results printed when the class is loaded.
 * @Author: wangyifei
 * @Date: 2022/9/18 9:57
 */
public class ForkJoinPoolTest {
    /** Number of elements in the demo {@link #array}. */
    private static final int ARRAY_SIZE = 10_000_000;

    /** Demo input: random values in {@code [0, 100)}, filled by the static initializer. */
    public static int[] array = new int[ARRAY_SIZE];

    /**
     * Split threshold. A task over the inclusive range {@code [start, end]} with
     * {@code end - start <= MAX_SIZE} (at most {@code MAX_SIZE + 1} elements) is computed
     * sequentially instead of being split further.
     */
    public static final int MAX_SIZE = 1000;

    static {
        Random random = new Random();
        for (int i = 0; i < array.length; i++) {
            array[i] = random.nextInt(100);
        }
        // Sequential reference results, printed in the same plain-int form main() uses.
        System.out.println("stream sum " + Arrays.stream(array).sum());
        System.out.println("stream max " + Arrays.stream(array).max().orElse(Integer.MIN_VALUE));
    }

    /**
     * Runs a {@link SumTask} and a {@link MaxTask} over the whole {@link #array} on a new pool,
     * prints the sum and then the maximum, and shuts the pool down.
     *
     * <p>A failure inside either task is rethrown by {@code join()} as an unchecked exception
     * instead of being printed and ignored.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            SumTask sumTask = new SumTask(array, 0, array.length - 1);
            MaxTask maxTask = new MaxTask(array, 0, array.length - 1);
            // Submit both before waiting so the two computations can overlap in the pool.
            pool.execute(sumTask);
            pool.execute(maxTask);
            System.out.println(sumTask.join());
            System.out.println(maxTask.join());
        } finally {
            pool.shutdown();
        }
    }

    /**
     * Computes the maximum of {@code array[start..end]} (both bounds inclusive) by recursive
     * halving until a range is no larger than the {@link ForkJoinPoolTest#MAX_SIZE} threshold.
     *
     * <p>An empty range ({@code start > end}) yields {@link Integer#MIN_VALUE}, the identity
     * of {@code max}.
     */
    public static class MaxTask extends RecursiveTask<Integer> {
        private final int[] array;
        private final int start;
        private final int end;

        /**
         * @param array the data to scan; not copied
         * @param start first index to include
         * @param end   last index to include (inclusive)
         */
        public MaxTask(int[] array, int start, int end) {
            this.array = array;
            this.start = start;
            this.end = end;
        }

        @Override
        protected Integer compute() {
            if (end - start <= MAX_SIZE) {
                // Start from MIN_VALUE (not 0) so ranges of negative values are handled correctly.
                int max = Integer.MIN_VALUE;
                for (int i = start; i <= end; i++) {
                    max = Math.max(max, array[i]);
                }
                return max;
            }
            int mid = start + (end - start) / 2; // overflow-safe midpoint
            MaxTask left = new MaxTask(array, start, mid);
            MaxTask right = new MaxTask(array, mid + 1, end);
            // Fork one half and compute the other on this thread instead of idling.
            // join() reports failures as unchecked exceptions, so nothing is silently swallowed.
            left.fork();
            int rightMax = right.compute();
            int leftMax = left.join();
            return Math.max(leftMax, rightMax);
        }
    }

    /**
     * Computes the sum of {@code array[start..end]} (both bounds inclusive) by recursive
     * halving until a range is no larger than the {@link ForkJoinPoolTest#MAX_SIZE} threshold.
     *
     * <p>Arithmetic is plain {@code int} addition, so overflow wraps around, the same as
     * {@code IntStream.sum()}. An empty range ({@code start > end}) yields 0.
     */
    public static class SumTask extends RecursiveTask<Integer> {
        private final int[] array;
        private final int start;
        private final int end;

        /**
         * @param array the data to sum; not copied
         * @param start first index to include
         * @param end   last index to include (inclusive)
         */
        public SumTask(int[] array, int start, int end) {
            this.array = array;
            this.start = start;
            this.end = end;
        }

        @Override
        protected Integer compute() {
            if (end - start <= MAX_SIZE) {
                int sum = 0;
                for (int i = start; i <= end; i++) {
                    sum += array[i];
                }
                return sum;
            }
            int mid = start + (end - start) / 2; // overflow-safe midpoint
            SumTask left = new SumTask(array, start, mid);
            SumTask right = new SumTask(array, mid + 1, end);
            // Fork one half and compute the other on this thread instead of idling.
            // join() reports failures as unchecked exceptions, so nothing is silently swallowed.
            left.fork();
            int rightSum = right.compute();
            return left.join() + rightSum;
        }
    }
}
